import sys

from PyQt5 import QtWidgets
from PyQt5.QtWidgets import QApplication, QMessageBox
from PyQt5 import QtGui
from PyQt5.QtGui import QIcon
import akshare as ak
import baostock as bs
from datetime import datetime, timedelta
import pandas as pd
import pyqtgraph as pg

from chart import ChartWidget, CandleItem, VolumeItem, MaItem
from pts.constant import Exchange, Interval
from pts.utility import import_data_from_csv
from pts.setting import SETTINGS, CHANSETTINGS
from pts.ui.download_config import DownloadConfig
from pts.constant import FREQS, FREQS_WINDOW, FREQS_EM, FREQS_BS


class CandleChartDialog(QtWidgets.QDialog):
    """
    K-line (candlestick) chart dialog.

    Shows locally cached K-line data of one stock as a candle plot with a
    volume plot underneath. Four buttons switch between the 1-minute,
    5-minute, 30-minute and daily intervals. The refresh button re-downloads
    the data from the source selected by SETTINGS["DataSource"]
    (2 = East Money via akshare, 3 = Sina via akshare, otherwise BaoStock).
    """

    # Intervals of the four interval buttons, in button order.
    _INTERVALS = (Interval.MINUTE, Interval.MINUTE5, Interval.MINUTE30, Interval.DAILY)
    # Index of the daily interval in _INTERVALS; used as the default view.
    _DAILY_INDEX = 3
    # Column layout of the K-line CSV files written by the downloaders.
    _CSV_COLUMNS = ("datetime", "open", "high", "low", "close", "volume", "amount")

    def __init__(self):
        """Initialise state, build the widgets and wire the interval buttons."""
        super().__init__()

        # Window icon. Forward slashes resolve on Windows and POSIX alike;
        # a backslash path only works on Windows.
        self.setWindowIcon(QIcon('images/YuLan.png'))

        # Whether data has been loaded into the chart
        self.updated = False

        # datetime -> bar index
        self.dt_ix_map = {}
        # bar index -> bar
        self.ix_bar_map = {}

        # Highest / lowest price over the loaded bars, and their difference
        self.high_price = 0
        self.low_price = 0
        self.price_range = 0

        # Extra graphics items added to the candle plot (see update_trades)
        self.items = []

        # Index of the selected interval button; -1 until one is selected
        self.curInterval = -1
        self.KLineData = r'Data/KLineData/'     # directory holding the K-line CSV files
        self.vt_symbol = ""
        # "border-radius" was previously written as "border - radius", which
        # Qt's stylesheet parser ignores.
        self.button_style_normal = "QPushButton{background-color:white;color:black;border-radius:10px;border:1px groove gray;border-style: outset;}"
        self.button_style_pressed = "QPushButton{background-color:blue;color:white;border-radius:10px;border:1px groove gray;border-style: outset;}"
        self.interval_button = []

        self.init_ui()
        handlers = (
            self.on_pushButton_clicked0,
            self.on_pushButton_clicked1,
            self.on_pushButton_clicked2,
            self.on_pushButton_clicked3,
        )
        for button, handler in zip(self.interval_button, handlers):
            button.clicked.connect(handler)

    def init_ui(self):
        """Build the chart, the information labels and the bottom button row."""
        self.setWindowTitle("K线图表")
        self.resize(1400, 800)

        # Chart at the top: candle plot above a volume plot
        self.chart = ChartWidget()
        self.chart.add_plot("candle", hide_x_axis=True)
        self.chart.add_plot("volume", maximum_height=200)
        self.chart.add_item(CandleItem, "candle", "candle")
        self.chart.add_item(VolumeItem, "volume", "volume")
        self.chart.add_cursor()

        # Information labels in the middle
        label1 = QtWidgets.QLabel("本功能使用本地数据生成K线图表。特定股票的本地数据可能不存在，也可能不是最新的。如有需要，请按“刷新行情数据”按钮下载最新的行情数据。")
        label2 = QtWidgets.QLabel("本功能的缠论分析需要调用普吸金缠论的二次开发工具，请参考文档《普吸金缠论大全》。")
        label3 = QtWidgets.QLabel("本功能使用了vn.py中K线图表功能的代码，相关代码的详细分析请参考《Python量化交易从入门到实战》一书及其附配文档《vn.py源代码深入分析》。")

        # Bottom row: interval buttons (same order as _INTERVALS), a spacer,
        # then the refresh button
        hbox1 = QtWidgets.QHBoxLayout()
        for text in ("1分钟", "5分钟", "30分钟", "1日"):
            button = QtWidgets.QPushButton(text)
            hbox1.addWidget(button)
            self.interval_button.append(button)
        spacerItem1 = QtWidgets.QSpacerItem(40, 20, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Minimum)
        hbox1.addItem(spacerItem1)
        self.btnRefresh = QtWidgets.QPushButton("刷新行情数据")
        self.btnRefresh.setObjectName("btnRefresh")
        hbox1.addWidget(self.btnRefresh)
        self.btnRefresh.clicked.connect(self.on_btnRefresh_clicked)

        # Set layout
        vbox = QtWidgets.QVBoxLayout()
        vbox.addWidget(self.chart)
        vbox.addWidget(label1)
        vbox.addWidget(label2)
        vbox.addWidget(label3)
        vbox.addLayout(hbox1)
        self.setLayout(vbox)

    def set_vt_symbol(self, vt_symbol):
        """Select the stock symbol and show its daily chart."""
        self.vt_symbol = vt_symbol

        self.set_interval_button_style(self._DAILY_INDEX)

    def set_interval_button_style(self, idx: int):
        """
        Select interval button ``idx`` (0=1m, 1=5m, 2=30m, 3=daily).

        Highlights that button and reloads the chart from the matching local
        CSV file.
        """
        self.curInterval = idx

        for i, button in enumerate(self.interval_button):
            button.setStyleSheet(self.button_style_pressed if i == idx else self.button_style_normal)

        # NOTE(review): the exchange is always SSE, even for Shenzhen symbols;
        # confirm whether import_data_from_csv depends on it.
        req = {
            'symbol': self.vt_symbol,
            'exchange': Exchange.SSE,
            'interval': self._INTERVALS[idx],
            'KLineData': self.KLineData,
        }

        # Read the bars from the CSV file
        history = import_data_from_csv(req)
        self.clear_data()
        self.update_history(history)

    def on_pushButton_clicked0(self):
        self.set_interval_button_style(0)

    def on_pushButton_clicked1(self):
        self.set_interval_button_style(1)

    def on_pushButton_clicked2(self):
        self.set_interval_button_style(2)

    def on_pushButton_clicked3(self):
        self.set_interval_button_style(3)

    def on_btnRefresh_clicked(self):
        """Handle the "refresh market data" button: re-download, then redraw."""
        if not self.vt_symbol:
            # The downloaders index self.vt_symbol[0]; an empty symbol would
            # raise IndexError (and leave a BaoStock session logged in).
            QMessageBox.information(self, '提示信息', '请先选择股票。')
            return

        # Optionally show the download configuration dialog first
        if SETTINGS["EveryTime"] != 0:
            w = DownloadConfig()
            if w.exec_() != QtWidgets.QDialog.Accepted:
                return

        source = SETTINGS["DataSource"]
        if source == 2:
            self.reloadHisFromEM()
        elif source == 3:
            self.reloadHisFromSina()
        else:
            self.reloadHisFromBS()

        # Redraw the selected interval; fall back to daily if none is selected
        idx = self.curInterval if self.curInterval >= 0 else self._DAILY_INDEX
        self.set_interval_button_style(idx)

    def update_history(self, history: list):
        """
        Load a list of bars into the chart.

        Also fills the index maps and extends the cached high/low price range.
        """
        self.updated = True
        self.chart.update_history(history)

        for ix, bar in enumerate(history):
            self.ix_bar_map[ix] = bar
            self.dt_ix_map[bar.datetime] = ix

            if not self.high_price:
                self.high_price = bar.high_price
                self.low_price = bar.low_price
            else:
                self.high_price = max(self.high_price, bar.high_price)
                self.low_price = min(self.low_price, bar.low_price)

        self.price_range = self.high_price - self.low_price

    def update_trades(self):
        """
        Draw a fixed "[1]" text marker on the candle plot.

        This system does not use trade data. The method only places a
        placeholder marker at (996, 10) and records it in ``self.items`` so
        that clear_data() removes it again.
        """
        candle_plot: pg.PlotItem = self.chart.get_plot("candle")

        text_color = "yellow"
        open_text: pg.TextItem = pg.TextItem("[1]", color=text_color, anchor=(0.5, 0.5))
        open_text.setPos(996, 10)

        self.items.append(open_text)
        candle_plot.addItem(open_text)

    def clear_data(self):
        """Remove all bars and extra items from the chart and reset cached state."""
        self.updated = False

        candle_plot = self.chart.get_plot("candle")
        for item in self.items:
            candle_plot.removeItem(item)
        self.items.clear()

        self.chart.clear_all()

        self.dt_ix_map.clear()
        self.ix_bar_map.clear()

        # Reset the price range; otherwise update_history() would keep
        # extending the previously loaded bars' high/low.
        self.high_price = 0
        self.low_price = 0
        self.price_range = 0

    def is_updated(self):
        """Return whether data has been loaded since the last clear_data()."""
        return self.updated

    def getStartDate(self):
        """
        Return ``{freq: datetime}`` download start dates for every freq in FREQS.

        The lookback depends on the interval: 1500 days for daily bars,
        210 days for 30-minute bars, 30 days for 5-minute bars and 7 days
        otherwise.
        """
        lookback_days = {
            Interval.DAILY: 1500,
            Interval.MINUTE30: 210,
            Interval.MINUTE5: 30,
        }
        now = datetime.today()
        return {
            freq: now - timedelta(days=lookback_days.get(FREQS_WINDOW[freq], 7))
            for freq in FREQS
        }

    def _save_kline_csv(self, freq, stock_data):
        """Rename the 7 columns of stock_data to _CSV_COLUMNS and write the cache CSV for freq."""
        import os

        stock_data.columns = list(self._CSV_COLUMNS)
        # Create the cache directory on first use instead of failing in to_csv
        os.makedirs(self.KLineData, exist_ok=True)
        file_name = self.KLineData + self.vt_symbol + '_' + FREQS_WINDOW[freq].value + '.csv'
        stock_data.to_csv(file_name, index=False)

    def reloadHisFromEM(self):
        """Download every interval in FREQS from East Money (akshare, qfq-adjusted) into the CSV cache."""
        dateDict = self.getStartDate()

        for freq in FREQS:
            try:
                if FREQS_WINDOW[freq] == Interval.DAILY:
                    stock_data = ak.stock_zh_a_hist(
                        symbol=self.vt_symbol, period="daily",
                        start_date=dateDict[freq].strftime('%Y%m%d'),
                        end_date=datetime.today().strftime('%Y%m%d'),
                        adjust="qfq")
                    time_column = "日期"
                else:
                    stock_data = ak.stock_zh_a_hist_min_em(
                        symbol=self.vt_symbol,
                        start_date=dateDict[freq].strftime('%Y-%m-%d %H:%M:%S'),
                        end_date=datetime.today().strftime('%Y-%m-%d %H:%M:%S'),
                        period=FREQS_EM[freq], adjust='qfq')
                    time_column = "时间"

                # Skip empty results so an existing cache file is not clobbered
                if stock_data is None or stock_data.empty:
                    continue

                # Keep only the needed fields, in CSV column order
                stock_data = stock_data[[time_column, "开盘", "最高", "最低", "收盘", "成交量", "成交额"]]
                self._save_kline_csv(freq, stock_data)

            except TypeError:
                QMessageBox.information(self, '提示信息', '由于服务器的原因，无法下载数据。')
            except Exception as ex:
                QMessageBox.information(self, '提示信息', '数据下载失败：{}。'.format(ex.args))
        QMessageBox.information(self, '提示信息', '从东方财富网下载数据完成。')

    def reloadHisFromSina(self):
        """Download every interval in FREQS from Sina (akshare, qfq-adjusted) into the CSV cache."""
        dateDict = self.getStartDate()

        # Sina symbols carry an exchange prefix: 'sh' for codes starting with 6
        vt_symbol = ('sh' if self.vt_symbol[0] == '6' else 'sz') + self.vt_symbol

        for freq in FREQS:
            try:
                is_daily = FREQS_WINDOW[freq] == Interval.DAILY
                if is_daily:
                    stock_data = ak.stock_zh_a_daily(
                        symbol=vt_symbol,
                        start_date=dateDict[freq].strftime('%Y%m%d'),
                        end_date=datetime.today().strftime('%Y%m%d'),
                        adjust="qfq")
                    time_column = "date"
                else:
                    stock_data = ak.stock_zh_a_minute(symbol=vt_symbol, period=FREQS_EM[freq], adjust="qfq")
                    time_column = "day"

                # Skip empty results so an existing cache file is not clobbered
                if stock_data is None or stock_data.empty:
                    continue

                if not is_daily:
                    # Sina minute data cannot be limited by date and may be
                    # long; keep only the last 1000 bars.
                    stock_data = stock_data.tail(1000)

                # Keep only the needed fields; copy so the column added below
                # goes into an independent frame, not a view of the download.
                stock_data = stock_data[[time_column, "open", "high", "low", "close", "volume"]].copy()
                # The selected Sina fields contain no turnover; fill it with 0
                stock_data['amount'] = 0
                self._save_kline_csv(freq, stock_data)

            except TypeError:
                QMessageBox.information(self, '提示信息', '由于服务器的原因，无法下载数据。')
            except Exception as ex:
                QMessageBox.information(self, '提示信息', '数据下载失败：{}。'.format(ex.args))
        QMessageBox.information(self, '提示信息', '从新浪下载数据完成。')

    @staticmethod
    def _bs_time_to_datetime(raw):
        """Convert a BaoStock 'time' value (YYYYMMDDHHMM...) to 'YYYY-MM-DD HH:MM:00'."""
        return f"{raw[:4]}-{raw[4:6]}-{raw[6:8]} {raw[8:10]}:{raw[10:12]}:00"

    def reloadHisFromBS(self):
        """Download every interval in FREQS except 1-minute from BaoStock (adjustflag 2) into the CSV cache."""
        lg = bs.login()
        if lg.error_code != '0':
            QMessageBox.information(self, '提示信息', '登录BaoStock失败。')
            return

        try:
            dateDict = self.getStartDate()

            # BaoStock symbols carry an exchange prefix: 'sh.' for codes starting with 6
            bs_symbol = ('sh.' if self.vt_symbol[0] == '6' else 'sz.') + self.vt_symbol
            end_date = datetime.today().strftime('%Y-%m-%d')

            for freq in FREQS:
                interval = FREQS_WINDOW[freq]
                # BaoStock does not provide 1-minute bars
                if interval == Interval.MINUTE:
                    continue

                try:
                    if interval == Interval.DAILY:
                        fields, frequency = "date,open,high,low,close,volume,amount", "d"
                    else:
                        fields, frequency = "time,open,high,low,close,volume,amount", FREQS_BS[freq]

                    rs = bs.query_history_k_data_plus(bs_symbol, fields,
                                                      start_date=dateDict[freq].strftime('%Y-%m-%d'),
                                                      end_date=end_date,
                                                      frequency=frequency, adjustflag="2")
                    if rs.error_code != '0':
                        QMessageBox.information(self, '提示信息', '数据下载失败：{}。'.format(rs.error_msg))
                        continue

                    data_list = []
                    while rs.error_code == '0' and rs.next():
                        row = rs.get_row_data()
                        if interval != Interval.DAILY:
                            row[0] = self._bs_time_to_datetime(row[0])
                        data_list.append(row)
                    stock_data = pd.DataFrame(data_list, columns=rs.fields)

                    # Skip empty results so an existing cache file is not clobbered
                    if stock_data.empty:
                        continue
                    self._save_kline_csv(freq, stock_data)

                except Exception as ex:
                    QMessageBox.information(self, '提示信息', '数据下载失败：{}。'.format(ex.args))
        finally:
            # Always end the BaoStock session, even if an error escapes
            bs.logout()

        QMessageBox.information(self, '提示信息', '从BaoStock下载数据完成。')


if __name__ == '__main__':
    app = QApplication(sys.argv)
    ui = CandleChartDialog()

    # Select the demo stock through the dialog's own API. This loads the daily
    # bars from the local CSV cache the same way the "1日" button does, and it
    # also records the symbol, so the interval and refresh buttons operate on
    # '002241' instead of an empty symbol.
    ui.set_vt_symbol('002241')
    # Place the demo "[1]" text marker on the candle plot
    ui.update_trades()

    ui.show()
    sys.exit(app.exec_())
